import json
from datetime import datetime, timezone
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_serializer, field_validator

from letta.schemas.letta_message_content import (
    LettaAssistantMessageContentUnion,
    LettaUserMessageContentUnion,
    get_letta_assistant_message_content_union_str_json_schema,
    get_letta_user_message_content_union_str_json_schema,
)

# ---------------------------
# Letta API Messaging Schemas
# ---------------------------


class MessageType(str, Enum):
    # Tag values used in the `message_type` field of every Letta API message below.
    # Mixing in `str` means each member compares equal to its plain string value
    # (e.g. MessageType.user_message == "user_message").
    system_message = "system_message"
    user_message = "user_message"
    assistant_message = "assistant_message"
    reasoning_message = "reasoning_message"
    hidden_reasoning_message = "hidden_reasoning_message"
    tool_call_message = "tool_call_message"
    tool_return_message = "tool_return_message"


class LettaMessage(BaseModel):
    """
    Base class for simplified Letta message response type. This is intended to be used for developers
    who want the internal monologue, tool calls, and tool returns in a simplified format that does not
    include additional information other than the content and timestamp.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (MessageType): The type of the message
        otid (Optional[str]): The offline threading id associated with this message
        sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
        step_id (Optional[str]): The id of the step associated with this message, if any
    """

    id: str
    date: datetime
    name: Optional[str] = None
    # Required on the base class; most subclasses override it with a Literal and a default.
    message_type: MessageType = Field(..., description="The type of the message.")
    otid: Optional[str] = None
    sender_id: Optional[str] = None
    step_id: Optional[str] = None

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        """
        Serialize ``date`` as an ISO-8601 string truncated to whole seconds.

        A naive datetime has no tzinfo, or its tzinfo reports no UTC offset. Such values are
        assumed to already be UTC and are tagged with ``timezone.utc`` before formatting, so
        the output always carries an offset (e.g. ``+00:00``). The wall-clock time is not shifted.

        Microseconds are dropped since it seems like we're inconsistent with getting them.
        TODO: figure out why we don't always get microseconds (get_utc_time() does)
        """
        # This is the standard "naive" test from the datetime docs.
        if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
            dt = dt.replace(tzinfo=timezone.utc)
        # timespec="seconds" truncates (does not round) any sub-second component.
        return dt.isoformat(timespec="seconds")


class SystemMessage(LettaMessage):
    """
    A message generated by the system. Never streamed back on a response, only used for cursor pagination.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        content (str): The message content sent by the system (plain text only)
    """

    # Narrowed to a single tag so this class can act as a member of a discriminated union.
    message_type: Literal[MessageType.system_message] = Field(
        default=MessageType.system_message,
        description="The type of the message.",
    )
    # Required: system content is always a plain string.
    content: str = Field(
        default=...,
        description="The message content sent by the system",
    )


class UserMessage(LettaMessage):
    """
    A message sent by the user. Never streamed back on a response, only used for cursor pagination.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts)
    """

    # Narrowed to a single tag so this class can act as a member of a discriminated union.
    message_type: Literal[MessageType.user_message] = Field(
        default=MessageType.user_message,
        description="The type of the message.",
    )
    # Either plain text or a list of typed content parts. The extra JSON-schema fragment
    # comes from the content-union helper module.
    content: Union[str, List[LettaUserMessageContentUnion]] = Field(
        default=...,
        description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
        json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
    )


class ReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning
            content was generated natively by a reasoner model or derived via prompting
            (defaults to "non_reasoner_model")
        reasoning (str): The internal reasoning of the agent
        signature (Optional[str]): The model-generated signature of the reasoning step
    """

    # Narrowed to a single tag so this class can act as a member of a discriminated union.
    message_type: Literal[MessageType.reasoning_message] = Field(
        default=MessageType.reasoning_message,
        description="The type of the message.",
    )
    source: Literal["reasoner_model", "non_reasoner_model"] = Field(default="non_reasoner_model")
    reasoning: str = Field(...)
    signature: Optional[str] = Field(default=None)


class HiddenReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning where reasoning content
    has been hidden from the response.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        state (Literal["redacted", "omitted"]): Whether the reasoning
            content was redacted by the provider or simply omitted by the API
        hidden_reasoning (Optional[str]): The internal reasoning of the agent
    """

    # Narrowed to a single tag so this class can act as a member of a discriminated union.
    message_type: Literal[MessageType.hidden_reasoning_message] = Field(
        default=MessageType.hidden_reasoning_message,
        description="The type of the message.",
    )
    # Required: callers must say why the reasoning is hidden.
    state: Literal["redacted", "omitted"] = Field(...)
    hidden_reasoning: Optional[str] = Field(default=None)


class ToolCall(BaseModel):
    # A complete tool invocation: every field is a required string.
    # Partial / streamed pieces use ToolCallDelta instead.
    name: str = Field(...)
    arguments: str = Field(...)
    tool_call_id: str = Field(...)


class ToolCallDelta(BaseModel):
    # Partial counterpart of ToolCall: each field may be missing from a given chunk.
    name: Optional[str] = None
    arguments: Optional[str] = None
    tool_call_id: Optional[str] = None

    def model_dump(self, *args, **kwargs):
        """
        Dump the delta with None-valued keys always omitted.

        OpenAI-style streaming chunks leave out keys whose value is null, so
        ``exclude_none`` is forced to True whatever the caller passed.
        """
        forced_kwargs = {**kwargs, "exclude_none": True}
        return super().model_dump(*args, **forced_kwargs)

    def json(self, *args, **kwargs):
        # Encode the None-free dump with the stdlib encoder; extra arguments go to json.dumps.
        payload = self.model_dump(exclude_none=True)
        return json.dumps(payload, *args, **kwargs)


class ToolCallMessage(LettaMessage):
    """
    A message representing a request to call a tool (generated by the LLM to trigger tool execution).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        tool_call (Union[ToolCall, ToolCallDelta]): The tool call
    """

    message_type: Literal[MessageType.tool_call_message] = Field(MessageType.tool_call_message, description="The type of the message.")
    tool_call: Union[ToolCall, ToolCallDelta]

    def model_dump(self, *args, **kwargs):
        """
        Dump the message with None values excluded, including inside ``tool_call``.

        ``exclude_none`` is always forced on so that a partial ToolCallDelta serializes
        without null keys. ``tool_call`` is looked up with ``.get`` because it is absent
        from the dump when the caller excludes it (e.g. ``exclude={"tool_call"}``).
        """
        kwargs["exclude_none"] = True
        data = super().model_dump(*args, **kwargs)
        tool_call = data.get("tool_call")
        if isinstance(tool_call, dict):
            data["tool_call"] = {k: v for k, v in tool_call.items() if v is not None}
        return data

    # NOTE(review): `class Config` / `json_encoders` is Pydantic v1-style configuration
    # (deprecated in v2). It is kept as-is for backward compatibility.
    class Config:
        json_encoders = {
            ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
            ToolCall: lambda v: v.model_dump(exclude_none=True),
        }

    @field_validator("tool_call", mode="before")
    @classmethod
    def validate_tool_call(cls, v):
        """
        Cast a raw dict into a ToolCall or a ToolCallDelta.

        Without this validator, Pydantic would fail when 'name', 'arguments' or
        'tool_call_id' is None instead of falling back to ToolCallDelta. The rules are:
        - A dict becomes a full ToolCall only when all three keys are present and non-None.
        - Otherwise, if at least one of the keys is present, it becomes a ToolCallDelta.
          Missing keys stay None.
        - Non-dict values are returned unchanged for normal Pydantic validation.

        Raises:
            ValueError: if a dict contains none of 'name', 'arguments' or 'tool_call_id'.
        """
        if not isinstance(v, dict):
            return v
        keys = ("name", "arguments", "tool_call_id")
        # Check values, not just keys: {"name": None, ...} must not be forced into ToolCall.
        if all(v.get(key) is not None for key in keys):
            return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
        if any(key in v for key in keys):
            return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
        raise ValueError("tool_call must contain at least one of 'name', 'arguments' or 'tool_call_id'")


class ToolReturnMessage(LettaMessage):
    """
    A message representing the return value of a tool call (generated by Letta executing the requested tool).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        tool_return (str): The return value of the tool
        status (Literal["success", "error"]): The status of the tool call
        tool_call_id (str): A unique identifier for the tool call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the tool invocation
        stderr (Optional[List[str]]): Captured stderr from the tool invocation
    """

    # Narrowed to a single tag so this class can act as a member of a discriminated union.
    message_type: Literal[MessageType.tool_return_message] = Field(
        default=MessageType.tool_return_message,
        description="The type of the message.",
    )
    tool_return: str = Field(...)
    status: Literal["success", "error"] = Field(...)
    tool_call_id: str = Field(...)
    # Captured output streams are optional; None means nothing was captured.
    stdout: Optional[List[str]] = Field(default=None)
    stderr: Optional[List[str]] = Field(default=None)


class AssistantMessage(LettaMessage):
    """
    A message sent by the LLM in response to user input. Used in the LLM context.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
    """

    # Narrowed to a single tag so this class can act as a member of a discriminated union.
    message_type: Literal[MessageType.assistant_message] = Field(
        default=MessageType.assistant_message,
        description="The type of the message.",
    )
    # Either plain text or a list of typed content parts. The extra JSON-schema fragment
    # comes from the content-union helper module.
    content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
        default=...,
        description="The message content sent by the agent (can be a string or an array of content parts)",
        json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
    )


# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
# Every member narrows `message_type` to a single MessageType Literal. Pydantic therefore picks
# the concrete class from that tag alone, instead of trying each member in turn.
# If a member is added or removed here, also update the hand-written OpenAPI schema in
# create_letta_message_union_schema.
LettaMessageUnion = Annotated[
    Union[SystemMessage, UserMessage, ReasoningMessage, HiddenReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
    Field(discriminator="message_type"),
]


def create_letta_message_union_schema():
    """
    Build the OpenAPI ``oneOf`` + discriminator schema describing ``LettaMessageUnion``.

    Each concrete message schema is referenced under ``#/components/schemas/``. Each
    ``message_type`` tag is mapped to its schema so that clients can dispatch on it.
    A fresh dict is built on every call, so callers may mutate the result.
    """
    # (discriminator value, component schema name), in union order.
    variants = (
        ("system_message", "SystemMessage"),
        ("user_message", "UserMessage"),
        ("reasoning_message", "ReasoningMessage"),
        ("hidden_reasoning_message", "HiddenReasoningMessage"),
        ("tool_call_message", "ToolCallMessage"),
        ("tool_return_message", "ToolReturnMessage"),
        ("assistant_message", "AssistantMessage"),
    )
    mapping = {tag: f"#/components/schemas/{schema}" for tag, schema in variants}
    one_of = [{"$ref": ref} for ref in mapping.values()]
    return {
        "oneOf": one_of,
        "discriminator": {"propertyName": "message_type", "mapping": mapping},
    }


# --------------------------
# Message Update API Schemas
# --------------------------


class UpdateSystemMessage(BaseModel):
    # Edit payload for an existing system message.
    # `content` is typed as a plain string, so its description must not advertise
    # multi-modal content parts. It now matches SystemMessage.content.
    message_type: Literal["system_message"] = "system_message"
    content: str = Field(..., description="The message content sent by the system")


class UpdateUserMessage(BaseModel):
    # Edit payload for an existing user message. The content accepts the same shapes as
    # UserMessage.content: plain text or a list of typed content parts.
    message_type: Literal["user_message"] = Field(default="user_message")
    content: Union[str, List[LettaUserMessageContentUnion]] = Field(
        default=...,
        description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
        json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
    )


class UpdateReasoningMessage(BaseModel):
    # Edit payload for a reasoning message. Unlike its siblings, `reasoning` is declared
    # before `message_type`. That order drives JSON-schema property order and dump key
    # order, so it is preserved.
    reasoning: str = Field(...)
    message_type: Literal["reasoning_message"] = Field(default="reasoning_message")


class UpdateAssistantMessage(BaseModel):
    # Edit payload for an existing assistant message. The content accepts the same shapes
    # as AssistantMessage.content: plain text or a list of typed content parts.
    message_type: Literal["assistant_message"] = Field(default="assistant_message")
    content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
        default=...,
        description="The message content sent by the assistant (can be a string or an array of content parts)",
        json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
    )


# Request body union for message edits, dispatched on the `message_type` tag.
# Each Update* model declares `message_type` as a single-string Literal with a default,
# which lets Pydantic select the model directly from the tag.
LettaMessageUpdateUnion = Annotated[
    Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
    Field(discriminator="message_type"),
]


# --------------------------
# Deprecated Message Schemas
# --------------------------


class LegacyFunctionCallMessage(LettaMessage):
    # Deprecated function-call message carrying the call as a single string.
    # NOTE(review): unlike the other legacy classes, this one does not override
    # `message_type`. The inherited field is required and must be a MessageType value.
    function_call: str = Field(...)


class LegacyFunctionReturn(LettaMessage):
    """
    A message representing the return value of a function call (generated by Letta executing the requested function).

    Args:
        function_return (str): The return value of the function
        status (Literal["success", "error"]): The status of the function call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        function_call_id (str): A unique identifier for the function call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the function invocation
        stderr (Optional[List[str]]): Captured stderr from the function invocation
    """

    # Legacy tag: a plain string literal rather than a MessageType member.
    message_type: Literal["function_return"] = Field(default="function_return")
    function_return: str = Field(...)
    status: Literal["success", "error"] = Field(...)
    function_call_id: str = Field(...)
    stdout: Optional[List[str]] = Field(default=None)
    stderr: Optional[List[str]] = Field(default=None)


class LegacyInternalMonologue(LettaMessage):
    """
    Representation of an agent's internal monologue (legacy format).

    Args:
        internal_monologue (str): The internal monologue of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Legacy tag: a plain string literal rather than a MessageType member.
    message_type: Literal["internal_monologue"] = Field(default="internal_monologue")
    internal_monologue: str = Field(...)


# Union of the deprecated message shapes. It has no discriminator: LegacyFunctionCallMessage
# declares no `message_type` Literal of its own, so Pydantic validates against the members
# using its default union matching. Member order is therefore kept as-is.
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
